Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor #2015

Draft
ricardoV94 wants to merge 7 commits into pymc-devs:v3 from ricardoV94:gather_scatter_fusion
Conversation

@ricardoV94
Member

@ricardoV94 ricardoV94 commented Mar 29, 2026

Summary

Introduce IndexedElemwise, an OpFromGraph that wraps AdvancedSubtensor + Elemwise + AdvancedIncSubtensor subgraphs so the Numba backend can generate a single loop with indirect indexing. This avoids materializing the AdvancedSubtensor input arrays and writes directly into the output buffer, doing the job of AdvancedIncSubtensor in the same loop instead of looping again over the intermediate elemwise output.

Commit 1 fuses indexed reads (AdvancedSubtensor1 on inputs).
Commit 2 fuses indexed updates (AdvancedIncSubtensor1 on outputs).
Commit 3 extends this to AdvancedSubtensor inputs, indexing arbitrary consecutive axes with 1d indices.

Motivation

In hierarchical models with mu = beta[group_idx] * x + ..., the logp+gradient graph combines indexed reads and indexed updates in the same Elemwise (the forward expands group-level parameters via advanced subtensor, and the gradient accumulates back into the source via advanced inc subtensor).

A simpler example:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

numba_mode = get_mode("NUMBA")
numba_mode_before = numba_mode.excluding("fuse_indexed_elemwise")

x = pt.vector("x")
idx = pt.vector("idx", dtype=int)
value = pt.vector("value")

y = pt.zeros(100)
out = ((x[idx] - value) ** 2).sum()
grad_wrt_x = pt.grad(out, x)
fn_before = pytensor.function([x, value, idx], [out, grad_wrt_x], mode=numba_mode_before, trust_input=True)
fn_before.dprint(print_op_info=True, print_destroy_map=True)
# Sum{axes=None} [id A] 5
#  └─ Composite{...}.0 [id B] d={0: [0]} 1
#     ├─ AdvancedSubtensor1 [id C] 0
#     │  ├─ x [id D]
#     │  └─ idx [id E]
#     └─ value [id F]
# AdvancedIncSubtensor1{inplace,inc} [id G] d={0: [0]} 4
#  ├─ Alloc [id H] 3
#  │  ├─ [0.] [id I]
#  │  └─ Shape_i{0} [id J] 2
#  │     └─ x [id D]
#  ├─ Composite{...}.1 [id B] d={0: [0]} 1
#  │  └─ ···
#  └─ idx [id E]

# Inner graphs:

# Composite{...} [id B] d={0: [0]}
#  ← sqr [id K] 'o0'
#     └─ sub [id L] 't5'
#        ├─ i0 [id M]
#        └─ i1 [id N]
#  ← mul [id O] 'o1'
#     ├─ 2.0 [id P]
#     └─ sub [id L] 't5'
#        └─ ···

fn = pytensor.function([x, value, idx], [out, grad_wrt_x], mode=numba_mode, trust_input=True)
fn.dprint(print_op_info=True, print_destroy_map=True)

# Sum{axes=None} [id A] 3
#  └─ IndexedElemwise{Composite{...}}.0 [id B] d={1: [3]} 2
#     ├─ x [id C] (indexed read (idx_0))
#     ├─ value [id D]
#     ├─ idx [id E] (idx_0)
#     └─ Alloc [id F] 1 (buf_0)
#        ├─ [0.] [id G]
#        └─ Shape_i{0} [id H] 0
#           └─ x [id C]
# IndexedElemwise{Composite{...}}.1 [id B] d={1: [3]} 2 (indexed inc (buf_0, idx_0))
#  └─ ···

# Inner graphs:

# IndexedElemwise{Composite{...}} [id B] d={1: [3]}
#  ← Composite{...}.0 [id I]
#     ├─ AdvancedSubtensor1 [id J]
#     │  ├─ *0-<Vector(float64, shape=(?,))> [id K]
#     │  └─ *2-<Vector(int64, shape=(?,))> [id L]
#     └─ *1-<Vector(float64, shape=(?,))> [id M]
#  ← AdvancedIncSubtensor1{inplace,inc} [id N] d={0: [0]}
#     ├─ *3-<Vector(float64, shape=(?,))> [id O]
#     ├─ Composite{...}.1 [id I]
#     │  └─ ···
#     └─ *2-<Vector(int64, shape=(?,))> [id L]

# Composite{...} [id I]
#  ← sqr [id P] 'o0'
#     └─ sub [id Q] 't0'
#        ├─ i0 [id R]
#        └─ i1 [id S]
#  ← mul [id T] 'o1'
#     ├─ 2.0 [id U]
#     └─ sub [id Q] 't0'
#        └─ ···

x_test = np.arange(15, dtype="float64")
idx_test = np.random.randint(15, size=(10_000,))
value_test = np.random.normal(size=idx_test.shape)

logp_before, dlogp_before = fn_before(x_test, value_test, idx_test)
logp, dlogp = fn(x_test, value_test, idx_test)
np.testing.assert_allclose(logp_before, logp)
np.testing.assert_allclose(dlogp_before, dlogp)

%timeit fn_before(x_test, value_test, idx_test)  # 29.4 μs ± 2.57 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit fn(x_test, value_test, idx_test)  # 13.8 μs ± 136 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

The next step would be to fuse the sum directly into the elemwise as well, so we end up with a single loop over the data. This matters because the sum can easily break our fusion: we don't fuse when the elemwise output is needed elsewhere (as it is in a sum).
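For reference, a hypothetical pure-Python sketch (not the PR's generated code; `fused_single_loop` and `n_grad` are illustration names) of what that fully fused single loop would compute for the example above: one pass that reads `x` indirectly, evaluates the elemwise body, accumulates the sum, and scatters the gradient.

```python
import numpy as np

def fused_single_loop(x, value, idx, n_grad):
    # One pass over the data: indexed read, elemwise body, fused sum,
    # and indexed inc into the gradient buffer.
    total = 0.0
    grad = np.zeros(n_grad)
    for k in range(idx.shape[0]):
        diff = x[idx[k]] - value[k]   # indexed read, no temporary array
        total += diff ** 2            # fused Sum{axes=None}
        grad[idx[k]] += 2.0 * diff    # indexed inc on the output buffer
    return total, grad
```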

@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 2 times, most recently from 6d875d8 to 0ad6e2e Compare March 29, 2026 18:14
@ricardoV94 ricardoV94 changed the title Numba: fuse AdvancedSubtensor+Elemwise Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor Mar 29, 2026
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 8 times, most recently from 9394746 to 41869a4 Compare April 1, 2026 21:35
Extend FusionOptimizer to merge independent subgraphs that share
inputs but have no producer-consumer edge (siblings like f(x) and g(x)).
The eager expansion only walks producer-consumer edges, missing these.

Also extract InplaceGraphOptimizer.try_inplace_on_node helper and
_insert_sorted_subgraph to deduplicate insertion-point logic.

Fuse single-client AdvancedSubtensor1 nodes into Elemwise loops,
replacing the materialized indexed read with a single iteration loop
that uses the index arrays for input access.

Before (2 nodes):
  temp = x[idx]                    # AdvancedSubtensor1, shape (919,)
  result = temp + y                # Elemwise

After (1 fused loop, x is read directly via idx):
  for k in range(919):
      result[k] = x[idx[k]] + y[k]

- Introduce IndexedElemwise Op (in rewriting/indexed_elemwise.py)
- Add FuseIndexedElemwise rewrite with SequenceDB
- Extend the _vectorized intrinsic to support indirect reads and writes
- Add op_debug_information for dprint(print_op_info=True)
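The read fusion can be illustrated with a pure-NumPy sketch (an illustration of the loop semantics, not the generated Numba code):

```python
import numpy as np

# Unfused: temp = x[idx] is materialized, then temp + y loops again.
# Fused: a single loop reads x indirectly via idx.
rng = np.random.default_rng(0)
x = rng.normal(size=85)
y = rng.normal(size=919)
idx = rng.integers(0, 85, size=919)

result = np.empty(919)
for k in range(919):
    result[k] = x[idx[k]] + y[k]  # indirect read inside the elemwise loop
```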
Extend the IndexedElemwise fusion to also absorb
AdvancedIncSubtensor1 (indexed set/inc) on the output side.

Before (3 nodes):
  temp = Elemwise(x[idx], y)               # shape (919,)
  result = IncSubtensor(target, temp, idx)  # target shape (85,)

After (1 fused loop, target is an input):
  for k in range(919):
      target[idx[k]] += scalar_fn(x[idx[k]], y[k])

- FuseIndexedElemwise now detects AdvancedIncSubtensor1 consumers
- Reject fusion when val broadcasts against target's non-indexed axes
- store_core_outputs supports inc mode via o[...] += val
- Inner fgraph always uses inplace IncSubtensor
- op_debug_information shows buf_N / idx_N linkage
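The update-side fusion can likewise be sketched in NumPy (illustration only): the elemwise result is accumulated straight into the target buffer, with no intermediate array.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=85)
y = rng.normal(size=919)
idx = rng.integers(0, 85, size=919)
target = np.zeros(85)

for k in range(919):
    target[idx[k]] += x[idx[k]] + y[k]  # elemwise body scattered in place
```

The unfused equivalent materializes `x[idx] + y` and then applies an inc_subtensor, which NumPy expresses as `np.add.at`.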
Support AdvancedSubtensor on any axis (not just axis 0) and multi-index
patterns like x[idx_row, idx_col] where multiple 1D index arrays address
consecutive source axes.  Generalize writes (AdvancedIncSubtensor) to
match.

Reads:
- Add undo_take_dimshuffle_for_fusion pre-fusion rewrite
- _get_indexed_read_info handles AdvancedSubtensor with consecutive
  tensor indices, full-slice prefix/suffix
- Reject boolean indices and non-consecutive advanced indices

Writes:
- _get_indexed_update_info mirrors _get_indexed_read_info for
  AdvancedIncSubtensor
- find_indexed_update_consumers detects both AdvancedIncSubtensor1
  and AdvancedIncSubtensor
- Broadcast guard generalized for non-axis-0 indexed axes
- Scatter construction supports AdvancedIncSubtensor (inplace)

Dispatch + codegen:
- indexed_inputs encoding: ((positions, axis, idx_bc), ...)
- input_read_spec uses tuple of (idx_k, axis) pairs per input
- n_index_loop_dims = max(idx.ndim for group)
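A NumPy sketch (illustration, not the actual codegen) of the multi-index read this commit handles: two 1D index arrays addressing consecutive axes of the source, fused into a single loop over the index length.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5, 6))
idx_row = rng.integers(0, 4, size=7)
idx_col = rng.integers(0, 5, size=7)

out = np.empty((7, 6))
for k in range(7):
    out[k] = x[idx_row[k], idx_col[k]]  # indirect read on two consecutive axes
```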
Support multidimensional (e.g. 2D matrix) and 0-d integer indices in
IndexedElemwise fusion, for both reads and writes.

ND indices:
- Add undo_take_reshape_for_fusion: undoes the Reshape+flatten pattern
  that transform_take applies for ND indices, recovering the original
  AdvancedSubtensor(source, mat_idx) form for fusion.
  Handles both axis=0 and axis>0 (with DimShuffle wrapping).
- idx_load_axes: tuple of tuples, each index array loads from idx_ndim
  loop counters

0-d indices:
- Accept 0-d tensor indices (e.g. x[scalar_idx, vec_idx]) which are
  valid AdvancedSubtensor inputs that broadcast with higher-dim indices.
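A NumPy sketch (illustration only) of the ND-index case: a 2D integer index drives one loop counter per index dimension, recovering `x[mat_idx]` without the Reshape+flatten round trip that transform_take introduces.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
mat_idx = rng.integers(0, 5, size=(2, 4))  # 2D integer index

out = np.empty((2, 4, 3))
for i in range(2):          # one loop counter per index dimension
    for j in range(4):
        out[i, j] = x[mat_idx[i, j]]
```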
Allow the written buffer in IndexedElemwise to be larger than scalar
in the core loop. When the Elemwise computes a scalar per index
iteration and the target has extra non-indexed dims, the scalar
broadcasts into those dims via o[...] = scalar.

Key changes:
- Generalize _getitem_0d_ellipsis to any ndim so o[...] += scalar
  works for 1-d+ core output arrays
- Relax the broadcast guard to allow val with fewer non-indexed dims
  when the indexed axes are rightmost (no trailing non-indexed dims)
- Transpose the target buffer in the fusion rewrite when non-indexed
  dims precede the indexed axis, so core dims are trailing (gufunc
  convention)
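A NumPy sketch (illustration only) of a scalar core output broadcasting into a larger target buffer: each loop iteration computes one scalar, which `o[...] += scalar` spreads across the target's non-indexed trailing dim.

```python
import numpy as np

idx = np.array([0, 2, 0])
vals = np.array([1.0, 2.0, 3.0])
target = np.zeros((4, 5))
for k in range(idx.shape[0]):
    target[idx[k]] += vals[k]  # scalar broadcasts across axis 1
```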
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 2 times, most recently from 0acf078 to c1553dc Compare April 2, 2026 20:36
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch from c1553dc to e286e58 Compare April 3, 2026 15:44